from isssm.typing import to_glssm
def clip_negative_evals(
    proposal: GLSSMProposal, *, eps: float = 1e-8
) -> GLSSMProposal:
    """Make the proposal's observation covariances ``Omega`` positive semi-definite.

    ``proposal.Omega`` is eigen-decomposed with ``jnp.linalg.eigh`` (batched over
    any leading axes). If every eigenvalue is strictly positive, the proposal is
    returned unchanged. Otherwise:

    - eigenvalues below ``eps`` are set to zero;
    - ``Omega`` is rebuilt from the clipped spectrum;
    - the pseudo-observations ``z`` are adjusted so that ``z - s`` lies in the
      image (column space) of the new ``Omega``.

    Here ``s`` are the smoothed signals of the Gaussian model
    ``to_glssm(proposal)``.

    Parameters
    ----------
    proposal : GLSSMProposal
        Proposal whose ``Omega`` may have non-positive eigenvalues.
    eps : float, optional, keyword-only
        Eigenvalues strictly below this threshold are set to zero. The default
        ``1e-8`` is the value that was previously hard-coded.

    Returns
    -------
    GLSSMProposal
        ``proposal`` itself if all eigenvalues are positive. Otherwise a copy
        with ``z`` and ``Omega`` replaced by their adjusted versions.
    """
    z = proposal.z
    Omega = proposal.Omega

    evals, evecs = jnp.linalg.eigh(Omega)
    # Python-level branch on concrete values: this function is not jit-able.
    if (evals > 0).all():
        return proposal

    # Smoothed signals of the Gaussian model built from the (unclipped) proposal.
    # Only computed when an adjustment is actually needed.
    glssm = to_glssm(proposal)
    filtered = kalman(z, glssm)
    s = smoothed_signals(filtered, z, glssm)

    evals_clipped = jnp.where(evals < eps, 0.0, evals)
    evecs_T = jnp.swapaxes(evecs, -1, -2)
    # V diag(lambda) V^T via broadcasting: scales eigenvector column k by lambda_k.
    # This works for a single matrix or any batch of matrices.
    Omega_new = (evecs * evals_clipped[..., None, :]) @ evecs_T

    # eigh returns orthonormal eigenvectors, so the orthogonal projector onto
    # im(Omega_new) is V_S V_S^T, where S is the set of retained (non-zero)
    # eigenvalues. This replaces the previous Cholesky + lstsq route:
    # - Cholesky is undefined for the singular Omega_new (JAX yields NaNs);
    # - jnp.linalg.lstsq returns a tuple and only accepts 2-D inputs.
    keep = (evals_clipped > 0).astype(evecs.dtype)
    projector = (evecs * keep[..., None, :]) @ evecs_T
    # Keep only the component of z - s in the support of the clipped covariance.
    z_new = s + (projector @ (z - s)[..., None])[..., 0]

    # NOTE(review): assumes GLSSMProposal is a NamedTuple (fields are read by
    # attribute above); confirm. The previous
    # GLSSMProposal(**proposal, z=..., Omega=...) could not work: `**` needs a
    # mapping and z/Omega would be passed twice. It also stored the unclipped
    # Omega instead of Omega_new.
    return proposal._replace(z=z_new, Omega=Omega_new)